import math
import random

import networkx as nx
import torch

from tools.utils import preprocess, dim_action_space


class BoltzmannBandit:
    """Bandit that picks a unit key of ``gwr`` by sampling a softmax.

    The logits combine, per unit, the log of the buffer's selection counter
    (scaled by ``args.skew_select``) and a running reward estimate stored in
    ``gwr.exec_reward_list`` (scaled by ``args.tau_coord``).
    """

    def __init__(self, gwr, args, action_space, save_dir, context=None):
        """Store the collaborators and build the Gaussian ``self.dist``.

        Args:
            gwr: Object exposing ``buffers``, ``exec_reward_list``,
                ``find_nearest_units``, ``neighbors`` and ``generate_all``.
            args: Namespace of hyper-parameters (``a_threshold``,
                ``skew_select``, ``tau_coord``, ``lr_coord``, ...).
            action_space: Passed to ``dim_action_space`` to get the action size.
            save_dir: Kept on the instance; not used by this class itself.
            context: Optional object exposing ``estimator.label_embed``; only
                needed by ``evaluate`` when no ``final_state`` is supplied.
        """
        self.cpt_maj = -1
        self.dist_weights = None
        # Logits of the last categorical built by random_unit (None before).
        self.weights = None
        self.gwr = gwr
        self.save_dir = save_dir
        # The bandit exposes itself under the ``actor_critic`` attribute
        # (presumably so callers can treat it like a policy -- TODO confirm).
        self.actor_critic = self
        self.action_space = action_space
        self.action_size = dim_action_space(self.action_space)
        self.args = args
        self.eval_mode = False
        self.context = context
        # Zero-mean normal with standard deviation ``args.a_threshold``.
        self.dist = torch.distributions.Normal(
            torch.zeros(1, self.action_size),
            torch.zeros(1, self.action_size) + self.args.a_threshold,
        )

    def check_invalid(self, buffer):
        """Return True when ``buffer`` must not be selected.

        A buffer is invalid when its ``learnDataStore`` reports nothing
        available or when its ``cpt_select_indexes`` counter is zero.
        """
        return not buffer.learnDataStore.available() or (buffer.cpt_select_indexes == 0)

    def argmax(self, iterable):
        """Return the index of the largest element (first one on ties)."""
        return max(enumerate(iterable), key=lambda x: x[1])[0]

    def uniform_random_unit(self):
        """Draw keys from ``gwr.buffers.random_key()`` until one is valid.

        NOTE(review): this never returns if every buffer is invalid.
        """
        key = self.gwr.buffers.random_key()
        while self.check_invalid(self.gwr.buffers[key]):
            key = self.gwr.buffers.random_key()
        return key

    def generate_all(self):
        """Delegate to ``gwr.generate_all()``."""
        return self.gwr.generate_all()

    def random_unit(self):
        """Sample a valid unit index from the softmax over unit scores.

        Side effects: increments ``cpt_maj`` and stores the logits and the
        categorical distribution in ``self.weights`` / ``self.dist_weights``.

        NOTE(review): ``weights[k]`` indexes a list with buffer keys and the
        sampled position is used as a key, so this assumes the keys of
        ``gwr.buffers`` are the integers ``0..n-1`` -- confirm with gwr.
        NOTE(review): the rejection loop never ends if no buffer is valid.
        """
        rew_list = self.gwr.exec_reward_list

        self.cpt_maj += 1
        exploration_weights = [
            math.log(self.gwr.buffers[k].cpt_select_indexes + 1) for k in self.gwr.buffers
        ]
        weights = [
            self.args.skew_select * w + self.args.tau_coord * r
            for w, r in zip(exploration_weights, rew_list)
        ]
        # Invalid units get a very low logit so they are practically never drawn.
        weights = [
            weights[k] if not self.check_invalid(self.gwr.buffers[k]) else -10000
            for k in self.gwr.buffers
        ]
        self.weights = torch.tensor(weights)
        self.dist_weights = torch.distributions.Categorical(logits=self.weights)
        sample = self.dist_weights.sample().item()
        # The mask is a finite logit, so re-draw in the rare case it is hit.
        while self.check_invalid(self.gwr.buffers[sample]):
            sample = self.dist_weights.sample().item()
        return sample

    def _blend_reward(self, key, rate, value):
        """Move ``gwr.exec_reward_list[key]`` towards ``value`` by ``rate``."""
        rewards = self.gwr.exec_reward_list
        rewards[key] = rewards[key] * (1 - rate) + rate * value

    def evaluate(self, returns, goal, obs, final_state=None, act_state=None):
        """Update the running reward estimates from a batch of returns.

        For every entry of ``returns`` the unit nearest to ``act_state`` (or to
        ``goal`` when ``act_state`` is None) is updated with ``args.lr_coord``.
        Depending on ``args.rew_coord_type`` further units are updated with
        ``args.lr_coord2``:

        * 5: the neighbours of that unit;
        * 6: the unit nearest to the reached state;
        * 7: the neighbours of the unit nearest to the reached state;
        * 8: both of the above.

        The reached state is ``final_state`` when given; otherwise it is the
        embedding of ``obs`` computed by ``context.estimator.label_embed``.

        Args:
            returns: Tensor-like; ``returns.shape[0]`` entries, each read with
                ``.item()``.
            goal: Query for ``gwr.find_nearest_units`` when ``act_state`` is None.
            obs: Observation embedded only when ``final_state`` is None.
            final_state: Optional reached state used instead of embedding ``obs``.
            act_state: Optional query used instead of ``goal``.
        """
        rew_type = self.args.rew_coord_type
        for i in range(returns.shape[0]):
            value = returns[i].item()
            true_key = self.gwr.find_nearest_units(goal if act_state is None else act_state)[0]
            self._blend_reward(true_key, self.args.lr_coord, value)

            if rew_type == 5:
                for neigh in self.gwr.neighbors(true_key):
                    self._blend_reward(neigh, self.args.lr_coord2, value)

            if rew_type in (6, 7, 8):
                # Only embed the observation when it is actually needed: the
                # embedding is ignored as soon as ``final_state`` is supplied,
                # and computing it requires a non-None ``context``.
                if final_state is None:
                    reached = self.context.estimator.label_embed(
                        preprocess(obs, self.args), act=True
                    ).cpu()
                else:
                    reached = final_state
                fake_key = self.gwr.find_nearest_units(reached)[0]

                if rew_type in (6, 8):
                    self._blend_reward(fake_key, self.args.lr_coord2, value)

                if rew_type in (7, 8):
                    for neigh in self.gwr.neighbors(fake_key):
                        self._blend_reward(neigh, self.args.lr_coord2, value)
        return

    def load(self):
        """No state to load."""
        pass

    def save(self):
        """No state to save."""
        pass


class BanditObs(BoltzmannBandit):
    """Bandit whose goals are embeddings of observations stored in buffers.

    A unit is chosen first; a noisy point around that unit's vector is then
    matched to the closest stored embedding, and the corresponding observation
    is re-embedded to produce the goal.

    ``check_invalid`` is inherited unchanged from ``BoltzmannBandit``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and reset the goal caches."""
        super().__init__(*args, **kwargs)
        self.goals_obs = None
        self.act_state = None
        self.final_act_state = None

    def spherical_sample(self, num_dim, noise, numbers=1, surface=False):
        """Sample ``numbers`` points in (or on) a ``num_dim``-dimensional ball.

        Uses the Muller method: a normalised Gaussian vector gives a uniform
        direction, and a radius ``noise * U**(1/num_dim)`` with ``U`` uniform
        in [0, 1) makes the points uniform inside the ball of radius ``noise``.
        See http://extremelearning.com.au/how-to-generate-uniformly-random-points-on-n-spheres-and-n-balls/
        and https://www.sciencedirect.com/science/article/pii/S0047259X10001211

        Args:
            num_dim: Dimension of the space.
            noise: Radius of the ball.
            numbers: Number of points to draw.
            surface: When True, return points on the unit sphere instead.

        Returns:
            Tensor of shape ``(numbers, num_dim)``.
        """
        normal_dist = torch.distributions.Normal(torch.zeros(numbers, num_dim), torch.ones(numbers, num_dim))
        orientation = normal_dist.sample()
        norm = torch.norm(orientation, 2, dim=1, keepdim=True)
        # Small constant avoids a division by zero for a degenerate draw.
        direction = orientation / (norm + 0.0001)
        if surface:
            return direction
        # Equal to Uniform(0, noise**num_dim) ** (1 / num_dim), but without
        # forming noise**num_dim, which underflows to 0 for large num_dim and
        # makes the Uniform distribution invalid (low >= high).
        r = noise * torch.pow(torch.rand(numbers, 1), 1. / num_dim)
        return r * direction

    def sample(self, numbers=1):
        """Draw noise vectors with the dimensions/radius configured in ``args``.

        The dimension is ``args.num_latents``, or 2 when ``args.state`` is set.
        """
        return self.spherical_sample(self.args.num_latents if not self.args.state else 2, self.args.a_threshold, numbers=numbers, surface=self.args.surface)

    def act(self, inputs, *args, imposed_cluster=None, predefined_goal=None, state=None, predefined_state=None, extend_rep=None, **kwargs):
        """Select a goal embedding and its unit for every row of ``inputs``.

        Args:
            inputs: Batch of observations (first dimension is the batch).
            imposed_cluster: Optional per-row unit keys that bypass selection.
            predefined_goal: When given it is returned as is, with ``None`` as
                the cluster.
            state: Optional state used instead of the embedding for nearest-unit
                queries and property updates.
            predefined_state: Accepted for interface compatibility; unused here.
            extend_rep: Optional representation; when given the goal is pushed
                away from it by a random fraction of a unit-length step.

        Returns:
            ``(actions.view(-1, args.num_latents), clusters)``.

        NOTE(review): ``self.act_state`` is overwritten on every row, so only
        the last row's state survives -- presumably the batch size is 1.
        """
        if predefined_goal is not None:
            return predefined_goal, None
        clusters = torch.empty(inputs.shape[0], 1, dtype=torch.long)
        if self.goals_obs is None or self.goals_obs.shape[0] != inputs.shape[0]:
            self.goals_obs = torch.empty(inputs.shape, dtype=inputs.dtype)

        # No goal is available yet, so the input itself is used as the goal.
        if self.gwr.insertions == 0:
            self.goals_obs[:] = inputs
            # Same as the main path below: acting must not build a graph.
            with torch.no_grad():
                actions = self.gwr.context.estimator.goal_embed(preprocess(inputs, self.args), act=True).cpu()
            clusters = self.gwr.find_nearest_units_all(actions if state is None else state)[1]
            if self.args.state:
                self.act_state = state
            return actions.view(-1, self.args.num_latents), clusters

        for i in range(inputs.shape[0]):
            # Cluster selection: imposed, uniform exploration, or softmax.
            if imposed_cluster is not None:
                clusters[i] = imposed_cluster[i]
            elif not self.eval_mode and (random.random() < self.args.epsilon or self.gwr.network.number_of_nodes() == 1):
                clusters[i] = self.uniform_random_unit()
            else:
                clusters[i] = self.random_unit()

            # Spherical sampling around the unit's vector.
            key = clusters[i].item()
            buffer = self.gwr.buffers[key]
            v = self.gwr.get(key)
            noise = self.sample()
            noisy_v = noise + v

            # Rejection sampling: retry (at most 10 more times) until the
            # noisy point is accepted by is_nearest_unit_among_neighbors.
            rejects = 1
            while not self.gwr.is_nearest_unit_among_neighbors(key, noisy_v):
                rejects += 1
                noise = self.sample()
                noisy_v = noise + v
                if rejects > 10:
                    break

            # Pick the stored observation whose embedding is closest to the
            # noisy point, or a random one when no embeddings are available
            # (or when args.nodes_number is set).
            embeddings = buffer.get_embeds_samples()
            if embeddings.shape[0] == 0 or self.args.nodes_number:
                self.goals_obs[i:i + 1] = buffer.learnDataStore.sample_obs(state=self.args.state)
                if self.args.state:
                    self.act_state = buffer.learnDataStore.tmp_state
            else:
                nearest = torch.argmin(torch.norm(embeddings - noisy_v.view(1, -1), 2, dim=1), dim=0)
                self.goals_obs[i:i + 1] = buffer.get_obs_to_embed(nearest.item(), state=self.args.state)
                if self.args.state:
                    self.act_state = buffer.tmp_state
        with torch.no_grad():
            actions = self.gwr.context.estimator.goal_embed(preprocess(self.goals_obs, self.args), act=True).cpu()
            if extend_rep is not None:
                extend_rep = self.context.estimator.label_embed(preprocess(inputs, self.args), act=True).cpu() if state is not None else extend_rep
                offset = actions - extend_rep
                # Clamp the norm so a goal equal to extend_rep yields a zero
                # push instead of NaN from 0 / 0.
                scale = torch.clamp(torch.norm(offset, 2), min=1e-8)
                actions = actions + torch.rand(1,) * offset / scale
        if not self.eval_mode and imposed_cluster is None:
            for i in range(inputs.shape[0]):
                self.gwr.update_properties(clusters[i].item(), vectors=actions[i] if state is None else self.act_state)
        return actions.view(-1, self.args.num_latents), clusters


class DijkstraBandit(BanditObs):
    """Bandit that reaches a final goal through intermediate units.

    A final goal/unit is chosen at ``step == 0``; afterwards the shortest path
    in ``gwr.network`` from the current unit to the final unit is used to pick
    the unit ``args.plan_steps`` hops ahead as the next sub-goal.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and reset the planning state."""
        super().__init__(*args, **kwargs)
        self.final_goal, self.final_cluster, self.goal, self.target_cluster = None, None, None, None
        self.path = None
        # Units visited since the last final goal was chosen.
        self.executed_path = []

    def act(self, inputs, *args, step=-1, predefined_goal=None, predefined_state=None, state=None, **kwargs):
        """Return the current (sub-)goal embedding and its target unit.

        Args:
            inputs: Current observation batch.
            step: Step index in the episode; ``0`` triggers the choice of a new
                final goal.
            predefined_goal: Optional final goal imposed by the caller at
                ``step == 0``.
            predefined_state: State matching ``predefined_goal``; used for the
                nearest-unit query when ``args.state`` is set.
            state: Optional state used instead of the embedding of ``inputs``.

        Returns:
            ``(self.goal.view(-1, args.num_latents), self.target_cluster)``.

        NOTE(review): the method expects a first call with ``step == 0``;
        before that ``self.goal`` is None and the comparison below fails.
        """
        if step == 0 and predefined_goal is not None:
            self.final_goal = predefined_goal.clone().cpu()
            self.final_cluster = torch.tensor(self.gwr.find_nearest_units(self.final_goal if not self.args.state else predefined_state)[0]).view(1, -1)
            self.final_act_state = predefined_state
            self.executed_path = []
        elif step == 0:
            self.final_goal, self.final_cluster = super().act(inputs, state=state)
            self.final_goal = self.final_goal.cpu()
            self.final_act_state = self.act_state
            self.executed_path = []

        actual_rep = self.context.estimator.label_embed(preprocess(inputs, self.args), act=True).cpu() if state is None else state
        actual_cluster = self.gwr.find_nearest_units(actual_rep)[0]

        compareg = self.goal if state is None else self.act_state
        # Re-plan at the start of an episode, or once the current sub-goal is
        # reached (and, with plan_steps == 1, the agent is in the target unit).
        if step == 0 or (torch.norm(actual_rep - compareg, 2, dim=1) < self.args.delta_reach
                         and (self.args.plan_steps != 1 or actual_cluster == self.target_cluster)):

            if not self.executed_path or actual_cluster != self.executed_path[-1]:
                self.executed_path.append(actual_cluster)
            path_exist = False
            # NOTE(review): this loops until a reachable final unit is drawn.
            while not path_exist:
                try:
                    self.path = nx.shortest_path(self.gwr.network, actual_cluster, self.final_cluster.item())
                    path_exist = True
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    # Only "no route" failures trigger a new final goal; any
                    # other error (or KeyboardInterrupt) now propagates.
                    if self.gwr.available_nodes[self.final_cluster.item()]:
                        self.gwr.delete_node(self.final_cluster.item())
                    self.final_goal, self.final_cluster = super().act(inputs, state=state)
                    self.final_goal = self.final_goal.cpu()
                    self.final_act_state = self.act_state
                    self.executed_path = []
            if len(self.path) <= self.args.plan_steps or self.path[self.args.plan_steps] == self.final_cluster:
                # The final unit is within reach of the planning horizon.
                if self.args.plan_version == 1:
                    self.goal, self.target_cluster = self.final_goal, self.final_cluster
                    self.act_state = self.final_act_state
                elif self.args.plan_version >= 2 and step % self.args.v2_duration == 0 and predefined_goal is None:
                    # Re-sample a goal inside the final unit every v2_duration steps.
                    if self.args.plan_version == 3:
                        self.goal, self.target_cluster = super().act(inputs, imposed_cluster=self.final_cluster, state=state, extend_rep=actual_rep)
                    else:
                        self.goal, self.target_cluster = super().act(inputs, imposed_cluster=self.final_cluster, state=state)

                    self.final_act_state = self.act_state
            else:
                # Aim for the unit plan_steps hops ahead on the path.
                self.goal, self.target_cluster = super().act(inputs, imposed_cluster=torch.tensor(self.path[self.args.plan_steps]).view(1, -1), state=state)

        return self.goal.view(-1, self.args.num_latents), self.target_cluster

    def load(self):
        """No state to load."""
        pass

    def save(self):
        """No state to save."""
        pass
